{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# --- Training inputs: base model and images ---\n",
    "# Path to the base model checkpoint.\n",
    "pretrained_model = \"./sd-models/model.ckpt\"\n",
    "# Path to the training dataset directory.\n",
    "train_data_dir = \"./train/aki\"\n",
    "\n",
    "# --- Training parameters ---\n",
    "# Image resolution as \"width,height\". Non-square sizes are supported, but both values must be multiples of 64.\n",
    "resolution = \"512,512\"\n",
    "# Batch size.\n",
    "batch_size = 1\n",
    "# Maximum number of training epochs.\n",
    "max_train_epoches = 10\n",
    "# Save a checkpoint every N epochs.\n",
    "save_every_n_epochs = 2\n",
    "# Network dim. Commonly 4~128; bigger is not necessarily better.\n",
    "network_dim = 32\n",
    "# Network alpha. Commonly equal to network_dim, or a smaller value such as half of network_dim, to prevent underflow.\n",
    "# The default value is 1; a smaller alpha requires a higher learning rate.\n",
    "network_alpha = 32\n",
    "# Clip skip. Chosen empirically; 2 is the usual value.\n",
    "clip_skip = 2\n",
    "# Train the U-Net only. Enabling this trades quality for a large reduction in VRAM use; can be enabled with 6 GB of VRAM.\n",
    "train_unet_only = 0\n",
    "# Train the text encoder only.\n",
    "train_text_encoder_only = 0\n",
    "\n",
    "# --- Learning rate ---\n",
    "lr = \"1e-4\"\n",
    "unet_lr = \"1e-4\"\n",
    "text_encoder_lr = \"1e-5\"\n",
    "# One of: \"linear\", \"cosine\", \"cosine_with_restarts\", \"polynomial\", \"constant\", \"constant_with_warmup\".\n",
    "lr_scheduler = \"cosine_with_restarts\"\n",
    "\n",
    "# --- Output settings ---\n",
    "# Name used for the saved model.\n",
    "output_name = \"aki\"\n",
    "# Saved model format: ckpt, pt, or safetensors.\n",
    "save_model_as = \"safetensors\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Launch LoRA training with the values defined in the settings cell (run that cell first).\n",
    "# Turn the train_unet_only / train_text_encoder_only switches into command-line flags;\n",
    "# without this step the switches would have no effect on the training run.\n",
    "if train_unet_only and train_text_encoder_only:\n",
    "    raise ValueError(\"train_unet_only and train_text_encoder_only are mutually exclusive; enable at most one.\")\n",
    "optional_flags = []\n",
    "if train_unet_only:\n",
    "    optional_flags.append(\"--network_train_unet_only\")\n",
    "if train_text_encoder_only:\n",
    "    optional_flags.append(\"--network_train_text_encoder_only\")\n",
    "extra_args = \" \".join(optional_flags)  # empty string when neither switch is on\n",
    "\n",
    "# IPython substitutes $name with the Python variable before handing the line to the shell.\n",
    "# Path-like and free-text values are double-quoted so that spaces do not split them into several arguments.\n",
    "!accelerate launch --num_cpu_threads_per_process=8 \"./sd-scripts/train_network.py\" \\\n",
    "  --enable_bucket \\\n",
    "  --pretrained_model_name_or_path=\"$pretrained_model\" \\\n",
    "  --train_data_dir=\"$train_data_dir\" \\\n",
    "  --output_dir=\"./output\" \\\n",
    "  --logging_dir=\"./logs\" \\\n",
    "  --resolution=$resolution \\\n",
    "  --network_module=networks.lora \\\n",
    "  --max_train_epochs=$max_train_epoches \\\n",
    "  --learning_rate=$lr \\\n",
    "  --unet_lr=$unet_lr \\\n",
    "  --text_encoder_lr=$text_encoder_lr \\\n",
    "  --network_dim=$network_dim \\\n",
    "  --network_alpha=$network_alpha \\\n",
    "  --output_name=\"$output_name\" \\\n",
    "  --lr_scheduler=$lr_scheduler \\\n",
    "  --train_batch_size=$batch_size \\\n",
    "  --save_every_n_epochs=$save_every_n_epochs \\\n",
    "  --mixed_precision=\"fp16\" \\\n",
    "  --save_precision=\"fp16\" \\\n",
    "  --seed=\"1337\" \\\n",
    "  --cache_latents \\\n",
    "  --clip_skip=$clip_skip \\\n",
    "  --prior_loss_weight=1 \\\n",
    "  --max_token_length=225 \\\n",
    "  --caption_extension=\".txt\" \\\n",
    "  --save_model_as=$save_model_as \\\n",
    "  --xformers --shuffle_caption --use_8bit_adam $extra_args"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.7 (tags/v3.10.7:6cc6b13, Sep  5 2022, 14:08:36) [MSC v.1933 64 bit (AMD64)]"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "675b13e958f0d0236d13cdfe08a1df3882cae564fa23a2e7e5eb1f2c6c632b02"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}